{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/qiangzibro/stacked_capsule_autoencoders.pytorch\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Directory of the project checkout to run from. Set the SCAE_ROOT environment\n",
    "# variable to override the machine-specific default instead of editing this cell.\n",
    "# Presumably the project-local imports in later cells resolve relative to it -- verify.\n",
    "PROJECT_ROOT = os.environ.get('SCAE_ROOT', '/home/qiangzibro/stacked_capsule_autoencoders.pytorch/')\n",
    "%cd {PROJECT_ROOT}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from model.models.ccae import CCAE\n",
    "from data_loader.ccae_dataloader import CCAE_Dataloader\n",
    "from model.loss import ccae_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG so runs are repeatable after a kernel restart.\n",
    "# NOTE(review): the dataset may draw from other RNGs (e.g. numpy) -- seed those too if so.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "B = 4  # batch size handed to CCAE_Dataloader\n",
    "k = 7  # number of points in each input set -- presumably the 7 in the printed batch shape; TODO confirm\n",
    "dim_input = 2  # 2 for 2D points\n",
    "# NOTE: 'speical' is a typo for 'special'; the name is kept because the model cell refers to it.\n",
    "dim_speical_features = 16\n",
    "n_votes = 4\n",
    "n_objects = 3\n",
    "\n",
    "learning_rate = 1e-3\n",
    "num_epochs = 15\n",
    "\n",
    "# Device configuration: use the GPU when one is available, otherwise the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_loader = CCAE_Dataloader(batch_size=B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the CCAE network from the hyper-parameters set earlier, then move it to the chosen device.\n",
    "model = CCAE(dim_input, n_objects, dim_speical_features, n_votes)\n",
    "model = model.to(device)\n",
    "\n",
    "# Adam optimizer over every parameter the model exposes.\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXdklEQVR4nO3dfZBd9X3f8ffn7pMe0eNaEiD0EAtjgQm21xpsiomNHB6cIJLYsUjdiBRXjl132kk8U2U043rIdIrtSdykobVV4kZ2Z4xt1QRlLAZLApc0RZSlAYMQQisBRkLSrtYg9IBW2t1v/7hn6WV1V/twzu6V7u/zmtnZ8/C75/fVj8tnz/7O2XsUEZiZWf0r1boAMzObGA58M7NEOPDNzBLhwDczS4QD38wsEY21LmAoc+fOjcWLF9e6DDOzC8pTTz11JCJaq+07bwN/8eLFtLe317oMM7MLiqRXhtrnKR0zs0Q48M3MEuHANzNLhAPfzCwRdRf4p3v76Tzew7Ge3lqXUpeO9/TSebyHnt6+WpdiZqN03t6lM1oRwXMH32RX5zFKEv0Bs6c08dGlc2lurLufaxPuTF8/f/9SN0eO91CS6Ivg8tZpXHPxDCTVujwzG4G6ScJX33iLXV3H6Qs40x/0RdB98jT/++XuWpdWF574xet0He95e3z7A/Z0nWBv94lal2ZmI1Q3gb+r8xh9/e/8qOf+gMPHezh1xtMPeZzp6+fA0bcYNLz0RbC783htijKzUaubwO/p7a+6vSRxuq/6PhuZ3sFJX6HHY2t2waibwF9w0SSqzSSXJKa11M2lipqY1FiiZYjrIPOnt0xwNWY2VnUT+FfOv4jmxhKlitRvkGhbOJOSLyrmIokPLZxFQ8U4lgRNDeLqBTNqWJmZjUbdnPpOaWrg1ivm8ULXcQ4fO8XU5kbe+67pzJnaXOvS6sIlMyaz8vJWXug8xrGeXt41tYX3zJvOlKaGWpdmZiNUN4EPMKmpgWsungH4rHM8zJ7SzEcWz6l1GWY2RnUzpWNmZudWSOBLulnSbkkdktado93vSApJbUX0a2ZmI5c78CU1APcCtwDLgTskLa/Sbjrwr4En8vZpZmajV8QZ/gqgIyL2RcRp4H5gVZV2fwp8DThVQJ9mZjZKRQT+JcCrFev7s21vk/QBYGFE/KSA/szMbAzG/aKtpBLw58Afj6DtWkntktq7urrGuzQzs6QUEfgHgIUV65dm2wZMB64CfibpZeBaYHO1C7cRsSEi2iKirbW16jN4zcxsjIoI/CeBZZKWSGoGVgObB3ZGxNGImBsRiyNiMbADuC0i/IRyM7MJlDvwI6IX+BLwMLAL+GFE7JR0t6Tb8h7fzMyKUchf2kbEFmDLoG1fGaLtrxXRp5lZPTl26gzfeHg3Dz7zGv0R3HrVAtbdfAWzCvx4mLr6aAUzswtRf3/wmQ076Og8xum+8seR//j/7mfHS91s/Tc3FPbUPn+0gplZjf2vjiO80n3i7bCH8pPljhzr4afPHyqsHwe+mVmN7Tr0ZtWHOJ043cfO194srB8HvplZjS2aM5WWprPjeEpzA0tbpxbWjwPfzKzGbrziXcyY1ERDxbOaSoLJTQ188n0LCuvHgW9mVmNNDSV+/IXr+CfLWmkoiYaSWLF4Ng988TqmNBd3b43v0jEzOw/MnzGJjX+wgp7ePiLKD3QqmgPfzOw80tI4fo8N9ZSOmVkiHPhmZolw4JuZJcKBb2aWCAe+mVkiHPhmZolw4JuZJcKBb2aWCAe+mVkiHPhmZokoJPAl3Sxpt6QOSeuq7P8jSc9L+rmk7ZIWFdGvmZmNXO7Al9QA3AvcAiwH7pC0fFCzfwTaIuJqYBPw9bz9mpnZ6BRxhr8C6IiIfRFxGrgfWFXZICIejYiT2eoO4NIC+jUzs1EoIvAvAV6tWN+fbRvKXcBDBfRrZmajMKEfjyzps0AbcMMQ+9cCawEuu+yyCazMzKz+FXGG
fwBYWLF+abbtHSStBNYDt0VET7UDRcSGiGiLiLbW1tYCSjMzswFFBP6TwDJJSyQ1A6uBzZUNJL0f+DblsO8soE8zMxul3IEfEb3Al4CHgV3ADyNip6S7Jd2WNfsGMA34kaSnJW0e4nBmZjZOCpnDj4gtwJZB275SsbyyiH7MzGzs/Je2ZmaJcOCbmSXCgW9mlggHvplZIhz4ZmaJcOCbmSXCgW9mlggHvplZIhz4ZmaJcOCbmSXCgW9mlggHvplZIhz4ZmaJcOCbmSXCgW9mlggHvplZIhz4ZmaJcOCbmSWikMCXdLOk3ZI6JK2rsr9F0g+y/U9IWlxEv2ZmNnK5A19SA3AvcAuwHLhD0vJBze4CXo+IdwPfBL6Wt18zMxudIs7wVwAdEbEvIk4D9wOrBrVZBWzMljcBN0pSAX2bmdkIFRH4lwCvVqzvz7ZVbRMRvcBRYM7gA0laK6ldUntXV1cBpZmZ2YDz6qJtRGyIiLaIaGttba11OWZmdaWIwD8ALKxYvzTbVrWNpEZgBtBdQN9mZjZCRQT+k8AySUskNQOrgc2D2mwG1mTLnwIeiYgooO+zHH3rDPu6T3DwzVP0j08XyYoIDh07xd7uE7z+1ulal2Nmo9SY9wAR0SvpS8DDQAPwnYjYKeluoD0iNgN/DXxPUgfwS8o/FArVH8HjL/+SA0ffYuB6cHNDiRuXtTKtJfc/M3knz/Sx/cVOTvX2M/BzdN70Fq5fOoeSr7+bXRAKScKI2AJsGbTtKxXLp4BPF9HXUPYeOcGBo6foCxhIpL7+Pv7h5W5ues+88ew6CY+/3M2J031U/s50+NgpXug8xvJ5F9WsLjMbufPqom0ee44cp2/QFE4Ab7x1hpOn+2pTVJ043dvPkROnGTxB1hfQceRETWoys9Grm8Dv668+Xy901g8CG51zXQsZatzN7PxTN4F/2awplKpMJbc0lpjW3DDxBdWRSU0NTGs+e/avJFg4c3INKjKzsaibwF8+bzpTmxtpyFK/JGgsiQ8vno3/qDe/axfNprGkt3+oNpbE5KYG3rfA8/dmF4q6uX2lqaHELVfM4xevn+Tw8R6mNTeydM5UpvjsvhBzpjbzG8vns6/7BMd6emmd1sKiWZNpLNXNOYNZ3aubwAdoKIklc6ayZM7UWpdSlyY3NXDlfJ/Rm12ofHpmZpYIB76ZWSIc+GZmiXDgm5klwoFvZpYIB76ZWSIc+GZmiXDgm5klwoFvZpYIB76ZWSIc+GZmiXDgm5klIlfgS5otaaukPdn3WVXaXCPpcUk7Jf1c0mfy9GlmZmOT9wx/HbA9IpYB27P1wU4Cvx8RVwI3A/9R0syc/ZqZ2SjlDfxVwMZseSNw++AGEfFiROzJll8DOoHWnP2amdko5Q38eRFxMFs+BMw7V2NJK4BmYO8Q+9dKapfU3tXVlbM0MzOrNOwDUCRtA+ZX2bW+ciUiQtKQT7SWtAD4HrAmIvqrtYmIDcAGgLa2Nj8d28ysQMMGfkSsHGqfpMOSFkTEwSzQO4dodxHwE2B9ROwYc7VmZjZmead0NgNrsuU1wIODG0hqBh4AvhsRm3L2Z2ZmY5Q38O8BPiFpD7AyW0dSm6T7sja/C3wUuFPS09nXNTn7NTOzUVLE+TlV3tbWFu3t7bUuw8zsgiLpqYhoq7bPf2lrZpYIB76ZWSIc+GZmiXDgm5klwoFvZpYIB76ZWSIc+GZmiXDgm5klYtjP0rH6t//1k3z7sX089crr/ErrVD5/w69w1cUzal2WmRXMgZ+4vV3HWXXvP3DqTB+9/cELh95k265OvvXZD3LD5X5sgVk98ZRO4u556AVOnO6lt7/8ERv9AW+d6WP93z7L+fqxG2Y2Ng78xD3xUjfVcv3wm6d481TvxBdkZuPGgZ+4GZObqm4vSUxq8tvDrJ74/+jE/YvrlzK5qeEd21oaS9z2qxfT0tgwxKvM7ELkwE/cP7t2Eb+34jJaGktMn9RIS2OJGy5v5e7brqp1aWZW
MH8evgHwxsnT7O06wSUzJzN/xqRal2NmY3Suz8P3bZkGwMwpzXxwUXOtyzCzceQpHTOzROQKfEmzJW2VtCf7PuscbS+StF/SX+Xp08zMxibvGf46YHtELAO2Z+tD+VPgsZz9mZnZGOUN/FXAxmx5I3B7tUaSPgjMA36asz8zMxujvIE/LyIOZsuHKIf6O0gqAX8GfHm4g0laK6ldUntXV1fO0szMrNKwd+lI2gbMr7JrfeVKRISkavd4fhHYEhH7JZ2zr4jYAGyA8m2Zw9VmZmYjN2zgR8TKofZJOixpQUQclLQA6KzS7MPA9ZK+CEwDmiUdj4hzzfebmVnB8t6HvxlYA9yTfX9wcIOI+KcDy5LuBNoc9mZmEy/vHP49wCck7QFWZutIapN0X97izMysOP5oBTOzOnKuj1bwX9qamSXCgW9mlggHvplZIhz4ZmaJcOCbmSXCgW9mlggHvplZIhz4ZmaJcOCbmSXCgW9mlggHvplZIhz4ZmaJcOCbmSXCgW9mlggHvplZIhz4ZmaJcOCbmSXCgW9mlohcgS9ptqStkvZk32cN0e4yST+VtEvS85IW5+nXzMxGL+8Z/jpge0QsA7Zn69V8F/hGRLwXWAF05uzXzMxGKW/grwI2ZssbgdsHN5C0HGiMiK0AEXE8Ik7m7NfMzEYpb+DPi4iD2fIhYF6VNpcDb0j6saR/lPQNSQ3VDiZpraR2Se1dXV05SzMzs0qNwzWQtA2YX2XX+sqViAhJMUQf1wPvB34B/AC4E/jrwQ0jYgOwAaCtra3asczMbIyGDfyIWDnUPkmHJS2IiIOSFlB9bn4/8HRE7Mte87fAtVQJfDMzGz95p3Q2A2uy5TXAg1XaPAnMlNSarX8ceD5nv2ZmNkp5A/8e4BOS9gArs3UktUm6DyAi+oAvA9slPQsI+K85+zUzs1EadkrnXCKiG7ixyvZ24HMV61uBq/P0ZWZm+fgvbc3MEuHANzNLhAPfzCwRDnwzs0Q48M3MEuHANzNLhAPfzCwRDnwzs0Q48M3MEuHANzNLhAPfzCwRDnwzs0Q48M3MEuHANzNLhAPfzCwRDnwzs0Q48M3MEuHANzNLRK7AlzRb0lZJe7Lvs4Zo93VJOyXtkvSXkpSnXzMzG728Z/jrgO0RsQzYnq2/g6SPANdRfqbtVcCHgBty9mtmZqOUN/BXARuz5Y3A7VXaBDAJaAZagCbgcM5+zcxslPIG/ryIOJgtHwLmDW4QEY8DjwIHs6+HI2JXtYNJWiupXVJ7V1dXztLMzKxS43ANJG0D5lfZtb5yJSJCUlR5/buB9wKXZpu2Sro+Iv5+cNuI2ABsAGhrazvrWGZmNnbDBn5ErBxqn6TDkhZExEFJC4DOKs1+C9gREcez1zwEfBg4K/DNzGz85J3S2QysyZbXAA9WafML4AZJjZKaKF+wrTqlY2Zm4ydv4N8DfELSHmBlto6kNkn3ZW02AXuBZ4FngGci4u9y9mtmZqM07JTOuUREN3Bjle3twOey5T7g83n6MTOz/PyXtmZmiXDgm5klwoFvZpYIB76ZWSIc+GZmiXDgm5klwoFvZpYIB76ZWSIc+GZmiXDgm5klwoFvZpYIB76ZWSIc+GZmiXDgm5klwoFvZpYIB76ZWSIc+GZmicj1xCszMyvOsweOsvnpA/T1B5+8egEfXDS70OPnCnxJnwa+CrwXWJE92rBau5uBvwAagPsi4p48/ZqZ1Zu/2PYi/+WxvZzu7SeA7z/5Kqs/tJB/95tXFtZH3imd54DfBh4bqoGkBuBe4BZgOXCHpOU5+zUzqxuvdJ/gP//PvZw6009/QAS8daaP+598lecOHC2sn1yBHxG7ImL3MM1WAB0RsS8iTgP3A6vy9GtmVk+2v9BJVNne09vHT58/VFg/E3HR9hLg1Yr1/dm2s0haK6ldUntXV9cElGZmVnvNjSUadPb2BolJTQ2F9TNs4EvaJum5Kl+Fn6VHxIaIaIuIttbW1qIPb2Z2Xrpp
+Xyiyil+Q0n8xtUXF9bPsBdtI2Jlzj4OAAsr1i/NtpmZGdA6vYU/+/Sv8sc/eoZSqXyq39cffPU3r+Sy2VMK62cibst8ElgmaQnloF8N/N4E9GtmdsH45NUXc9275/LIC530RfCx97yLudNaCu0j722ZvwX8J6AV+ImkpyPiJkkXU7798taI6JX0JeBhyrdlficiduau3Myszsyc0sxvf+DScTt+rsCPiAeAB6psfw24tWJ9C7AlT19mZpaPP1rBzCwRDnwzs0Q48M3MEuHANzNLhKLa3f7nAUldwCvj3M1c4Mg491EE11ks11m8C6XWFOpcFBFV/3L1vA38iSCpPSLaal3HcFxnsVxn8S6UWlOv01M6ZmaJcOCbmSUi9cDfUOsCRsh1Fst1Fu9CqTXpOpOewzczS0nqZ/hmZslw4JuZJaKuA1/SbElbJe3Jvs+q0uZjkp6u+Dol6fZs399Ieqli3zW1rDVr11dRz+aK7UskPSGpQ9IPJDXXqk5J10h6XNJOST+X9JmKfeM6ppJulrQ7G4d1Vfa3ZOPTkY3X4op9f5Jt3y3ppiLrGkOdfyTp+Wz8tktaVLGv6nugRnXeKamrop7PVexbk71P9khaU+M6v1lR44uS3qjYN5Hj+R1JnZKeG2K/JP1l9u/4uaQPVOzLP54RUbdfwNeBddnyOuBrw7SfDfwSmJKt/w3wqfOpVuD4ENt/CKzOlr8FfKFWdQKXA8uy5YuBg8DM8R5Tyh+/vRdYCjQDzwDLB7X5IvCtbHk18INseXnWvgVYkh2noYZ1fqziffiFgTrP9R6oUZ13An9V5bWzgX3Z91nZ8qxa1Tmo/b+i/DHtEzqeWV8fBT4APDfE/luBhwAB1wJPFDmedX2GT/lh6Ruz5Y3A7cO0/xTwUEScHM+ihjDaWt8mScDHgU1jef0oDVtnRLwYEXuy5deATsrPTBhvK4COiNgXEaeB+7N6K1XWvwm4MRu/VcD9EdETES8BHdnxalJnRDxa8T7cQflJcRNtJOM5lJuArRHxy4h4HdgK3Hye1HkH8P1xquWcIuIxyieVQ1kFfDfKdgAzJS2goPGs98CfFxEHs+VDwLxh2q/m7DfCv89+tfqmpGIfP/NOI611ksoPet8xMPUEzAHeiIjebH3IB8VPYJ0ASFpB+axrb8Xm8RrTS4BXK9arjcPbbbLxOkp5/Eby2omss9JdlM/6BlR7D4yHkdb5O9l/z02SBh5nel6OZzY1tgR4pGLzRI3nSAz1bylkPCfiEYfjStI2YH6VXesrVyIiJA15D2r2U/R9lJ/MNeBPKIdaM+X7Yv8tcHeNa10UEQckLQUekfQs5dAqTMFj+j1gTUT0Z5sLHdN6J+mzQBtwQ8Xms94DEbG3+hHG3d8B34+IHkmfp/zb08drVMtIrAY2RURfxbbzaTzH1QUf+HGOh6xLOixpQUQczMKn8xyH+l3ggYg4U3HsgTPZHkn/DfhyrWuNiAPZ932Sfga8H/gflH/1a8zOWnM9KL6IOiVdBPwEWJ/9ajpw7ELHdJADwMKK9WrjMNBmv6RGYAbQPcLXTmSdSFpJ+YfsDRHRM7B9iPfAeATUsHVGRHfF6n2Ur/EMvPbXBr32Z4VX+P/7Gul/u9XAv6zcMIHjORJD/VsKGc96n9LZDAxczV4DPHiOtmfN62WBNjBHfjtQ9cp6QYatVdKsgSkQSXOB64Dno3xV51HK1yCGfP0E1tlM+dGX342ITYP2jeeYPgksU/mOpWbK/3MPvuuisv5PAY9k47cZWK3yXTxLgGXA/ymwtlHVKen9wLeB2yKis2J71fdADetcULF6G7ArW34Y+PWs3lnAr/PO354ntM6s1isoX/B8vGLbRI7nSGwGfj+7W+da4Gh2klTMeE7U1elafFGem90O7AG2AbOz7W2UH7I+0G4x5Z+gpUGvfwR4lnIo/XdgWi1rBT6S1fNM9v2uitcvpRxQHcCPgJYa1vlZ4AzwdMXXNRMxppTvcniR8hna
+mzb3ZSDE2BSNj4d2XgtrXjt+ux1u4Fbxvm9OVyd24DDFeO3ebj3QI3q/A/AzqyeR4ErKl77z7Nx7gD+oJZ1ZutfBe4Z9LqJHs/vU75r7Qzlefi7gD8E/jDbL+De7N/xLNBW5Hj6oxXMzBJR71M6ZmaWceCbmSXCgW9mlggHvplZIhz4ZmaJcOCbmSXCgW9mloj/B/UQq88PXVO9AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from data_loader.ccae_dataloader import CCAE_Dataset\n",
    "from utils.plot import plot_concellation\n",
    "\n",
    "\n",
    "def test_ccae_dataset():\n",
    "    \"\"\"Build a CCAE_Dataset with its default arguments and plot the sample at index 0.\"\"\"\n",
    "    first_sample = CCAE_Dataset()[0]\n",
    "    plot_concellation(first_sample)\n",
    "\n",
    "\n",
    "test_ccae_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 7, 2])\n"
     ]
    }
   ],
   "source": [
    "# Look at one batch only, to check the shape of the 'corners' tensor the loader yields.\n",
    "first_batch = next(iter(data_loader), None)\n",
    "if first_batch is not None:\n",
    "    print(first_batch['corners'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 16]], which is output 0 of SliceBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-f944a25a010d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/miniconda3/envs/scae/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    219\u001b[0m                 \u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    220\u001b[0m                 create_graph=create_graph)\n\u001b[0;32m--> 221\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/miniconda3/envs/scae/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m    130\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m    131\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    134\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 16]], which is output 0 of SliceBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True)."
     ]
    }
   ],
   "source": [
    "# Start training\n",
    "#\n",
    "# The last recorded run of this cell failed inside loss.backward() with a RuntimeError about a\n",
    "# variable being modified by an inplace operation. Anomaly detection (the remedy named in that\n",
    "# error's hint) makes autograd report the forward operation responsible. It slows training,\n",
    "# so set DETECT_ANOMALY to False once the in-place operation has been located and fixed.\n",
    "DETECT_ANOMALY = True\n",
    "torch.autograd.set_detect_anomaly(DETECT_ANOMALY)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for i, data in enumerate(data_loader):\n",
    "        # Forward pass: the same batch is both the model input and the loss target.\n",
    "        x = data['corners'].to(device)\n",
    "        res_dict = model(x)\n",
    "        loss = ccae_loss(res_dict, x)\n",
    "\n",
    "        # Backprop and optimize\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # .item() yields a plain Python number instead of the tensor repr with its grad_fn.\n",
    "        print(f'epoch {epoch + 1}/{num_epochs}, step {i}: loss = {loss.item():.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
